#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file    demo.py
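@brief   PyQt5 demo for hand-written digit recognition.
@details Draw a digit on the board with the left mouse button, right-click to
         run the model on the drawing, and double-click the left button to
         clear the board.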
@author  Shivelino
@date    2023-12-23 19:10
@version 0.0.1

@par Copyright(c):
@par todo:
@par history:
"""
import torch

import argparse
import torchvision.transforms as transforms

from nets import get_model
from utils import get_device

from PyQt5 import QtCore, QtGui, QtWidgets
import sys
from PyQt5.QtWidgets import QApplication, QLabel, QMainWindow
from PyQt5.QtGui import QPainter, QPen
from PyQt5.QtCore import Qt
import numpy as np
import cv2
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort that can occur when several of the imported native libraries each
# bundle their own OpenMP runtime - confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def qimage2opencv(image):
    """Copy a QImage-like object into an OpenCV-style BGR array.

    Args:
        image: object exposing ``width()``, ``height()`` and
            ``pixelColor(x, y)``, where the returned color offers
            ``red()``, ``green()`` and ``blue()``.

    Returns:
        ``numpy.ndarray`` of shape ``(height, width, 3)`` and dtype
        ``uint8`` with the channels stored in blue, green, red order.
    """
    cols = image.width()
    rows = image.height()
    bgr = np.zeros((rows, cols, 3), dtype=np.uint8)
    # Visit every (row, col) pair in row-major order and read that pixel.
    for row, col in np.ndindex(rows, cols):
        color = image.pixelColor(col, row)
        bgr[row, col] = (color.blue(), color.green(), color.red())  # BGR
    return bgr


class Inferior(object):
    """Load a trained classifier once and run single-image inference with it."""

    def __init__(self, opt):
        """Build the network selected by ``opt.model`` and load its weights.

        Args:
            opt: object with a ``model`` attribute (the parsed command-line
                namespace in this file). The value is passed to
                ``get_model`` and also selects the checkpoint file
                ``model/model_<name>.pth`` (relative to the working
                directory).
        """
        # load model
        self.device = get_device()
        self.model = get_model(opt.model).to(self.device)
        # map_location makes the checkpoint load onto the device in use now,
        # even if it was saved from a different one (e.g. saved on a GPU and
        # loaded on a CPU-only machine).
        state_dict = torch.load(f'model/model_{opt.model}.pth', map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Applied to a single 1-D vector of class scores, hence dim=0.
        self.softmax = torch.nn.Softmax(dim=0)
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    def infer(self, img_np):
        """Classify one image.

        Args:
            img_np: image accepted by ``self.transform``; the caller in this
                file passes a 28x28 single-channel ``uint8`` array
                (presumably turned into a 1x28x28 tensor by the transform).

        Returns:
            Tuple ``(result, confidence)``: ``result`` is the ``int`` index of
            the highest-scoring class and ``confidence`` is its softmax
            probability as a 0-dim tensor.
        """
        # Add the batch dimension expected by the model.
        img_tensor = self.transform(img_np).unsqueeze_(0)
        with torch.no_grad():
            outputs = self.model(img_tensor.to(self.device)).to("cpu")
            output = self.softmax(outputs[0])  # scores of the only batch item
            result = int(torch.argmax(output))
            confidence = output[result]
        return result, confidence


class MainWindow(QMainWindow):
    """Main window: a drawing board on the left, prediction read-outs on the right.

    Mouse controls (handled by the window itself):
      * drag with the left button  - draw a stroke on the board
      * press the right button     - run inference and show result/confidence
      * double-click the left button - clear the board and the read-outs

    Strokes are painted directly into the board label's pixmap as the mouse
    moves; the label is asked to repaint only when the pixmap changed.
    """

    PEN_WIDTH = 98  # stroke thickness in board pixels

    def __init__(self, inferior):
        """Create the window.

        Args:
            inferior: object whose ``infer(img_np)`` returns
                ``(result, confidence)``; used when the right button is pressed.
        """
        super().__init__()
        self.inferior = inferior  # inference engine

        self.points = []  # board-relative points of the stroke being drawn
        self.setupUi(self)

    @staticmethod
    def _headline_font(family=None):
        """Return the bold 64pt font used by all labels, optionally with a family."""
        font = QtGui.QFont()
        if family is not None:
            font.setFamily(family)
        font.setPointSize(64)
        font.setBold(True)
        return font

    def setupUi(self, mainwindow):
        """Create and lay out all child widgets of ``mainwindow``."""
        mainwindow.setObjectName("mainwindow")
        mainwindow.resize(1280, 720)
        self.centralwidget = QtWidgets.QWidget(mainwindow)
        self.centralwidget.setObjectName("centralwidget")

        # Drawing board: a fixed 700x700 label whose pixmap is the canvas.
        self.boarder = QLabel(self.centralwidget)
        self.boarder.setGeometry(QtCore.QRect(10, 10, 700, 700))
        self.boarder.setMinimumSize(QtCore.QSize(700, 700))
        self.boarder.setMaximumSize(QtCore.QSize(700, 700))
        self.boarder.setCursor(QtGui.QCursor(QtCore.Qt.CrossCursor))
        self.boarder.setObjectName("boarder")
        self.boarder.setStyleSheet("border: 1px solid black;")
        self.boarder.setFixedSize(700, 700)
        # The initial canvas is a rendering of the still-empty label itself.
        self.boarder.setPixmap(self.boarder.grab())
        # NOTE(review): this looks like a reference to the label's live pixmap
        # rather than an independent snapshot - confirm before relying on it.
        self.blank_board = self.boarder.pixmap()

        # Read-out for the predicted digit.
        self.result = QtWidgets.QLabel(self.centralwidget)
        self.result.setGeometry(QtCore.QRect(720, 180, 550, 160))
        self.result.setFont(self._headline_font())
        self.result.setAlignment(QtCore.Qt.AlignCenter)
        self.result.setObjectName("result")
        self.result.setStyleSheet("border: 1px solid black;")

        # Caption above the predicted digit.
        self.text1 = QtWidgets.QLabel(self.centralwidget)
        self.text1.setGeometry(QtCore.QRect(720, 10, 550, 160))
        self.text1.setFont(self._headline_font("Microsoft YaHei UI"))
        self.text1.setAlignment(QtCore.Qt.AlignCenter)
        self.text1.setObjectName("text1")

        # Read-out for the prediction confidence.
        self.confidence = QtWidgets.QLabel(self.centralwidget)
        self.confidence.setGeometry(QtCore.QRect(720, 550, 550, 160))
        self.confidence.setFont(self._headline_font())
        self.confidence.setAlignment(QtCore.Qt.AlignCenter)
        self.confidence.setObjectName("confidence")
        self.confidence.setStyleSheet("border: 1px solid black;")

        # Caption above the confidence read-out.
        self.text2 = QtWidgets.QLabel(self.centralwidget)
        self.text2.setGeometry(QtCore.QRect(720, 370, 550, 160))
        self.text2.setFont(self._headline_font("Microsoft YaHei UI"))
        self.text2.setAlignment(QtCore.Qt.AlignCenter)
        self.text2.setObjectName("text2")
        mainwindow.setCentralWidget(self.centralwidget)

        self.retranslateUi(mainwindow)
        QtCore.QMetaObject.connectSlotsByName(mainwindow)

    def retranslateUi(self, mainwindow):
        """Set the window title and the two caption texts."""
        _translate = QtCore.QCoreApplication.translate
        mainwindow.setWindowTitle(_translate("mainwindow", "手写数字识别演示程序"))
        self.text1.setText(_translate("mainwindow", "识别结果"))
        self.text2.setText(_translate("mainwindow", "识别置信度"))

    def _board_pos(self, event):
        """Return the position of ``event`` in the board label's coordinates."""
        return self.boarder.mapFrom(self, event.pos())

    def _stroke_pen(self):
        """Return the black pen used for drawing strokes."""
        pen = QPen()
        pen.setWidth(self.PEN_WIDTH)
        pen.setColor(Qt.black)
        return pen

    def mousePressEvent(self, event):
        """Start a new stroke (left button) or run inference (right button)."""
        if event.button() == Qt.LeftButton:
            # Store the start point in board coordinates, like every later
            # point of the stroke, so the first segment is not shifted.
            self.points = [self._board_pos(event)]
        elif event.button() == Qt.RightButton:
            result, confidence = self.infer()
            self.result.setText(f"{result}")
            self.confidence.setText(f"{confidence: .4f}")

    def mouseMoveEvent(self, event):
        """Extend the current stroke while the left button is held down."""
        if not (event.buttons() & Qt.LeftButton):
            return  # dragging with another button must not draw
        pos = self._board_pos(event)
        if self.points:
            # Paint only the new segment; earlier ones are already in the pixmap.
            painter = QPainter(self.boarder.pixmap())
            painter.setPen(self._stroke_pen())
            painter.drawLine(self.points[-1], pos)
            painter.end()
            self.boarder.update()
        self.points.append(pos)

    def mouseDoubleClickEvent(self, event):
        """Clear the board and both read-outs on a left double-click."""
        if event.button() == Qt.LeftButton:
            # clear board
            painter = QPainter(self.boarder.pixmap())
            painter.eraseRect(self.boarder.rect())
            painter.end()
            self.boarder.update()
            self.points.clear()
            # clear read-outs
            self.result.setText("")
            self.confidence.setText("")

    def infer(self):
        """Run the model on the current board content.

        The board pixmap is converted to a BGR array, resized to 28x28,
        turned into grayscale and inverted before it is handed to
        ``self.inferior.infer``.

        Returns:
            Whatever ``self.inferior.infer`` returns; unpacked by the caller
            as ``(result, confidence)``.
        """
        # convert img
        img_np = qimage2opencv(self.boarder.pixmap().toImage())
        img_np = cv2.resize(img_np, (28, 28))
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_np = cv2.bitwise_not(img_np)  # dark strokes become bright
        return self.inferior.infer(img_np)  # infer


if __name__ == '__main__':
    # Read the command line and build the inference engine before any
    # window is created.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str, default="lenet", help='model')
    options = cli.parse_args()
    custom_inferior = Inferior(options)

    # Start Qt, show the demo window and hand control to the event loop;
    # the loop's return value becomes the process exit status.
    app = QApplication(sys.argv)
    window = MainWindow(custom_inferior)
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)
